
include("../src/qmcmary.jl")
using ..qmcmary

using LinearAlgebra


"""
    gf1()

Cross-check a `ScrollSVD` built from random `fakedata` B matrices against
explicit matrix products, printing residuals that are expected to be ≈ 0:

- `prod(reverse(bmats))` versus an explicit loop product;
- `ss.F[end]` versus `b2*b1` and `ss.F[1]` versus `b1` (both rebuilt as
  `U * Diagonal(S) * Vt`);
- unitarity of the stored `U` and `Vt` factors of `ss.F[1]`;
- the Green's matrix from `eq_green_scratch(ss)` versus `inv(I + b2*b1)`.

Returns `nothing`; all results are printed.
"""
function gf1()
    ng = 5      # number of B matrices per group
    nsite = 4   # size argument passed to `fakedata`
    idmat = Diagonal(ones(nsite))

    bmats = [fakedata(nsite)[1] for _ in 1:ng]
    println(bmats)

    ss = ScrollSVD(bmats)
    println(ss)

    # Append a second group of ng matrices to the scroll.
    bmats2 = [fakedata(nsite)[1] for _ in 1:ng]
    push!(ss, bmats2)

    # Ordered products B_ng * ... * B_1: later matrices multiply from the left.
    b1 = prod(reverse(bmats))
    b2 = prod(reverse(bmats2))
    println(b2 * b1)

    # Same product via an explicit loop. `b` must stay an ordinary local:
    # declaring it `global` inside the loop is a syntax error here, because
    # `b` is already a local variable of `gf1`.
    b = idmat
    for idx in ng:-1:1
        b = b * bmats[idx]
    end
    println(b1 - b)

    F = ss.F[end]
    println(b2 * b1 - F.U * Diagonal(F.S) * F.Vt)

    F = ss.F[1]
    println(b1 - F.U * Diagonal(F.S) * F.Vt)
    # SVD factors are unitary, so check against the adjoint (identical to the
    # transpose for real matrices, and the correct check for complex ones).
    println(F.U * adjoint(F.U))
    println(inv(F.U) - adjoint(F.U))
    println("vt")
    println(F.Vt * adjoint(F.Vt))
    println(inv(F.Vt) - adjoint(F.Vt))

    # NOTE(review): assumes eq_green_scratch returns (G, phase), matching how
    # gf2 destructures its result.
    gr, _ = eq_green_scratch(ss)
    gr2 = inv(idmat + b2 * b1)
    println(gr - gr2)
    return nothing
end

"""
    _svd_green(M, UL, UR)

From the SVD `M = Fm.U * Diagonal(Fm.S) * Fm.Vt` (computed with
`QRIteration`, i.e. LAPACK `gesvd`), return `(G, phase)` where

- `G = adjoint(Fm.Vt * UL) * inv(Diagonal(Fm.S)) * adjoint(UR * Fm.U)`, which
  equals `inv(UR * M * UL)` when `UL` and `UR` are unitary;
- `phase = p / abs(p)` with `p = det(Fm.Vt * UL) * det(UR * Fm.U)`.
"""
function _svd_green(M, UL, UR)
    Fm = svd(M; alg=LinearAlgebra.QRIteration())
    G = adjoint(Fm.Vt * UL) * inv(Diagonal(Fm.S)) * adjoint(UR * Fm.U)
    p = det(Fm.Vt * UL) * det(UR * Fm.U)
    return G, p / abs(p)
end

"""
    gf2()

Build a random model, set up a `ScrollSVD`-style stack via `initialize_SS`,
and compare two evaluations of `inv(I + B)` for `B = U*Diagonal(S)*Vt` taken
from `ss.F[end]`:

1. B treated as the right factor (`UL = DL = VL = I`);
2. B treated as the left factor (`UR = DR = VR = I`), with `DL` split into
   large (> 1) and small (≤ 1) parts.

Prints both diagonals, their maximum absolute difference and both phases.
It then scrolls the last `Ng` B matrices with `scrollR2L` and prints the
result of `eq_green_scratch`. Returns `nothing`.
"""
function gf2()
    # Random Hermitian hopping matrix, uniform on-site U, and per-site axis
    # components (nx, ny, nz) = (0, sin θ, cos θ).
    Nx = 18
    hk = rand(ComplexF64, Nx, Nx)
    hk += adjoint(hk)
    Ui = fill(6.0, Nx)
    the = rand(Nx) * pi
    nx = zeros(Nx)
    ny = sin.(the)
    nz = cos.(the)

    # Random auxiliary-field configuration with entries in {-1, +1}.
    Nt = 200
    Ng = 5
    hscfg = rand((-1, 1), Nt, Nx)
    sp = default_splitting(Nt, hk, Ui)
    _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    idmat = Diagonal(ones(2 * Nx))   # identity of size 2Nx
    F = ss.F[end]                    # B = F.U * Diagonal(F.S) * F.Vt

    ## (1) B on the right: I + UR*DR*VR = UR * (UR' + DR*VR).
    UR, DR, VR = F.U, Diagonal(F.S), F.Vt
    M = adjoint(UR) + DR * VR
    gtt, phase = _svd_green(M, idmat, UR)

    ## (2) B on the left: I + VL*DL*UL = (UL' + VL*DL) * UL,
    ## with DL = DLB * DLS, DLB = max(DL, 1), DLS = min(DL, 1).
    VL, DL, UL = F.U, Diagonal(F.S), F.Vt
    DLS = Diagonal(min.(F.S, 1.0))
    DLB = Diagonal(max.(F.S, 1.0))
    println(diag(DL))
    println(diag(DLB * DLS))
    # NOTE(review): DLS is multiplied back in before the SVD, so this M equals
    # UL' + VL*DL up to rounding and the split does not yet isolate the small
    # scales. A factored variant tried earlier was
    #   adjoint(UL)*inv(DLS)*adjoint(Fm.Vt)*inv(Diagonal(Fm.S))*adjoint(UR*Fm.U)
    M = (adjoint(UL) * inv(DLS) + VL * DLB) * DLS
    gtt2, phase2 = _svd_green(M, UL, idmat)

    println(diag(gtt))
    println(diag(gtt2))
    println(maximum(abs, gtt2 - gtt))
    println(phase, " ", phase2)

    ## (3) Pass the last Ng B matrices to scrollR2L (presumably updates ss in
    ## place; TODO confirm), then recompute from scratch.
    scrollR2L(ss, allbmats[Nt-Ng+1:Nt])
    gt, ph = eq_green_scratch(ss)
    println(diag(gt))
    println(ph)
end


# Run the gf2 comparison whenever this script is run or included
# (gf1 is defined above but not invoked).
gf2()
